Exploratory Image Analysis
To familiarize ourselves with the dataset, we explore basic statistics and class characteristics of the dataset to evaluate it quantitatively and qualitatively.
We examine visual features that YOLOv8 relies on at low-level (edges, textures, color), mid-level (object parts such as beards and hats), and high-level (humanoid structure and red–white gestalt). This helps us identify potential biases and confusion sources, particularly with Santa-like hard negatives (e.g. the Grinch).
Guidelines for the project
data inspection: how many samples, classes, labels, class imbalances
object classes, distributions, statistics, imbalances, bias — consider when training, e.g. weighted loss
qualitative and quantitative sense of data
discuss dataset challenges and how to overcome them
visualize image labels visualize bounding boxes
Connect to Google Drive and load the Roboflow dataset
# Mount Google Drive into the Colab VM and change into the project folder.
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)
# NOTE(review): the traceback captured below ("OSError: [Errno 107] Transport
# endpoint is not connected") indicates a stale Drive mount; restarting the
# Colab runtime before re-mounting resolves it.
%cd /content/drive/My Drive/FHNW/HS_25/DLBS/minichallenge_hs25_object_detection/MC
ERROR:root:Internal Python error in the inspect module. Below is the traceback from this internal error. ERROR:root:Internal Python error in the inspect module. Below is the traceback from this internal error. ERROR:root:Internal Python error in the inspect module. Below is the traceback from this internal error.
Mounted at /content/drive/
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipython-input-1413283218.py", line 4, in <cell line: 0>
get_ipython().run_line_magic('cd', '/content/drive/My Drive/FHNW/HS_25/DLBS/minichallenge_hs25_object_detection/MC')
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2418, in run_line_magic
result = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "<decorator-gen-85>", line 2, in cd
File "/usr/local/lib/python3.12/dist-packages/IPython/core/magic.py", line 187, in <lambda>
call = lambda f, *a, **k: f(*a, **k)
^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/magics/osm.py", line 342, in cd
oldcwd = os.getcwd()
^^^^^^^^^^^
OSError: [Errno 107] Transport endpoint is not connected
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2099, in showtraceback
stb = value._render_traceback_()
^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'OSError' object has no attribute '_render_traceback_'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1101, in get_records
return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 248, in wrapped
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 281, in _fixed_getinnerframes
records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1769, in getinnerframes
traceback_info = getframeinfo(tb, context)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1714, in getframeinfo
filename = getsourcefile(frame) or getfile(frame)
^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 970, in getsourcefile
module = getmodule(object, filename)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 999, in getmodule
file = getabsfile(object, _filename)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 983, in getabsfile
return os.path.normcase(os.path.abspath(_filename))
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<frozen posixpath>", line 415, in abspath
OSError: [Errno 107] Transport endpoint is not connected
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipython-input-1413283218.py", line 4, in <cell line: 0>
get_ipython().run_line_magic('cd', '/content/drive/My Drive/FHNW/HS_25/DLBS/minichallenge_hs25_object_detection/MC')
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2418, in run_line_magic
result = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "<decorator-gen-85>", line 2, in cd
File "/usr/local/lib/python3.12/dist-packages/IPython/core/magic.py", line 187, in <lambda>
call = lambda f, *a, **k: f(*a, **k)
^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/magics/osm.py", line 342, in cd
oldcwd = os.getcwd()
^^^^^^^^^^^
OSError: [Errno 107] Transport endpoint is not connected
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2099, in showtraceback
stb = value._render_traceback_()
^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'OSError' object has no attribute '_render_traceback_'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
if (await self.run_code(code, result, async_=asy)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3575, in run_code
self.showtraceback(running_compiled_code=True)
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2101, in showtraceback
stb = self.InteractiveTB.structured_traceback(etype,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1367, in structured_traceback
return FormattedTB.structured_traceback(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1267, in structured_traceback
return VerboseTB.structured_traceback(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1124, in structured_traceback
formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1082, in format_exception_as_a_whole
last_unique, recursion_repeat = find_recursion(orig_etype, evalue, records)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 382, in find_recursion
return len(records), 0
^^^^^^^^^^^^
TypeError: object of type 'NoneType' has no len()
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2099, in showtraceback
stb = value._render_traceback_()
^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'TypeError' object has no attribute '_render_traceback_'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1101, in get_records
return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 248, in wrapped
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 281, in _fixed_getinnerframes
records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1769, in getinnerframes
traceback_info = getframeinfo(tb, context)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1714, in getframeinfo
filename = getsourcefile(frame) or getfile(frame)
^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 970, in getsourcefile
module = getmodule(object, filename)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 999, in getmodule
file = getabsfile(object, _filename)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 983, in getabsfile
return os.path.normcase(os.path.abspath(_filename))
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<frozen posixpath>", line 415, in abspath
OSError: [Errno 107] Transport endpoint is not connected
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipython-input-1413283218.py", line 4, in <cell line: 0>
get_ipython().run_line_magic('cd', '/content/drive/My Drive/FHNW/HS_25/DLBS/minichallenge_hs25_object_detection/MC')
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2418, in run_line_magic
result = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "<decorator-gen-85>", line 2, in cd
File "/usr/local/lib/python3.12/dist-packages/IPython/core/magic.py", line 187, in <lambda>
call = lambda f, *a, **k: f(*a, **k)
^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/magics/osm.py", line 342, in cd
oldcwd = os.getcwd()
^^^^^^^^^^^
OSError: [Errno 107] Transport endpoint is not connected
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2099, in showtraceback
stb = value._render_traceback_()
^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'OSError' object has no attribute '_render_traceback_'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
if (await self.run_code(code, result, async_=asy)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3575, in run_code
self.showtraceback(running_compiled_code=True)
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2101, in showtraceback
stb = self.InteractiveTB.structured_traceback(etype,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1367, in structured_traceback
return FormattedTB.structured_traceback(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1267, in structured_traceback
return VerboseTB.structured_traceback(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1124, in structured_traceback
formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1082, in format_exception_as_a_whole
last_unique, recursion_repeat = find_recursion(orig_etype, evalue, records)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 382, in find_recursion
return len(records), 0
^^^^^^^^^^^^
TypeError: object of type 'NoneType' has no len()
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2099, in showtraceback
stb = value._render_traceback_()
^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'TypeError' object has no attribute '_render_traceback_'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
return runner(coro)
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
coro.send(None)
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3492, in run_ast_nodes
self.showtraceback()
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2101, in showtraceback
stb = self.InteractiveTB.structured_traceback(etype,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1367, in structured_traceback
return FormattedTB.structured_traceback(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1267, in structured_traceback
return VerboseTB.structured_traceback(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1142, in structured_traceback
formatted_exceptions += self.format_exception_as_a_whole(etype, evalue, etb, lines_of_context,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1082, in format_exception_as_a_whole
last_unique, recursion_repeat = find_recursion(orig_etype, evalue, records)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 382, in find_recursion
return len(records), 0
^^^^^^^^^^^^
TypeError: object of type 'NoneType' has no len()
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2099, in showtraceback
stb = value._render_traceback_()
^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'TypeError' object has no attribute '_render_traceback_'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 1101, in get_records
return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 248, in wrapped
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/IPython/core/ultratb.py", line 281, in _fixed_getinnerframes
records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1769, in getinnerframes
traceback_info = getframeinfo(tb, context)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 1714, in getframeinfo
filename = getsourcefile(frame) or getfile(frame)
^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 970, in getsourcefile
module = getmodule(object, filename)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 999, in getmodule
file = getabsfile(object, _filename)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 983, in getabsfile
return os.path.normcase(os.path.abspath(_filename))
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<frozen posixpath>", line 415, in abspath
OSError: [Errno 107] Transport endpoint is not connected
from helpers import (
load_roboflow_data
analyze_dataset,
visualize_images,
)
from google.colab import userdata
api_key = userdata.get('ROBOFLOW_API_KEY')
!pip install roboflow
from roboflow import Roboflow
rf = Roboflow(api_key=api_key)
project = rf.workspace('dlbs-xi5zk').project('santa-qqpxm')
version = project.version(10)
dataset = version.download('yolov8')
Requirement already satisfied: roboflow in /usr/local/lib/python3.12/dist-packages (1.2.11) Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from roboflow) (2025.11.12) Requirement already satisfied: idna==3.7 in /usr/local/lib/python3.12/dist-packages (from roboflow) (3.7) Requirement already satisfied: cycler in /usr/local/lib/python3.12/dist-packages (from roboflow) (0.12.1) Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from roboflow) (1.4.9) Requirement already satisfied: matplotlib in /usr/local/lib/python3.12/dist-packages (from roboflow) (3.10.0) Requirement already satisfied: numpy>=1.18.5 in /usr/local/lib/python3.12/dist-packages (from roboflow) (2.0.2) Requirement already satisfied: opencv-python-headless==4.10.0.84 in /usr/local/lib/python3.12/dist-packages (from roboflow) (4.10.0.84) Requirement already satisfied: Pillow>=7.1.2 in /usr/local/lib/python3.12/dist-packages (from roboflow) (11.3.0) Requirement already satisfied: pi-heif<2 in /usr/local/lib/python3.12/dist-packages (from roboflow) (1.1.1) Requirement already satisfied: pillow-avif-plugin<2 in /usr/local/lib/python3.12/dist-packages (from roboflow) (1.5.2) Requirement already satisfied: python-dateutil in /usr/local/lib/python3.12/dist-packages (from roboflow) (2.9.0.post0) Requirement already satisfied: python-dotenv in /usr/local/lib/python3.12/dist-packages (from roboflow) (1.2.1) Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from roboflow) (2.32.4) Requirement already satisfied: six in /usr/local/lib/python3.12/dist-packages (from roboflow) (1.17.0) Requirement already satisfied: urllib3>=1.26.6 in /usr/local/lib/python3.12/dist-packages (from roboflow) (2.5.0) Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.12/dist-packages (from roboflow) (4.67.1) Requirement already satisfied: PyYAML>=5.3.1 in /usr/local/lib/python3.12/dist-packages (from 
roboflow) (6.0.3) Requirement already satisfied: requests-toolbelt in /usr/local/lib/python3.12/dist-packages (from roboflow) (1.0.0) Requirement already satisfied: filetype in /usr/local/lib/python3.12/dist-packages (from roboflow) (1.2.0) Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->roboflow) (1.3.3) Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->roboflow) (4.61.1) Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->roboflow) (25.0) Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->roboflow) (3.2.5) Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->roboflow) (3.4.4) loading Roboflow workspace... loading Roboflow project...
Check Dataset Content and Structure
We inspect images qualitatively, perform class counts and annotation consistency.
import os

# Inspect the downloaded dataset directory: count the images per split and
# read the YOLO data.yaml configuration.
dataset_path = dataset.location
split_counts = {}

# Count images in each of the three standard YOLO splits.
for item in sorted(os.listdir(dataset_path)):
    item_path = os.path.join(dataset_path, item)
    if os.path.isdir(item_path) and item in ['train', 'valid', 'test']:
        image_dir = os.path.join(item_path, 'images')
        # Guard against a split directory without an images/ subfolder
        # (the original would crash with FileNotFoundError here).
        if os.path.isdir(image_dir):
            split_counts[item] = len(os.listdir(image_dir))

# Calculate and display statistics
total = sum(split_counts.values())
print(f"\nπ \nTotal Images (Train + Valid + Test): {total}")
print("Count per split:")
for split in ['train', 'valid', 'test']:
    count = split_counts.get(split, 0)
    pct = (count / total * 100) if total > 0 else 0
    print(f"- {split.capitalize()}: {count} images ({pct:.1f}%)")

# Load and display the data.yaml file
import yaml

yaml_path = os.path.join(dataset_path, "data.yaml")
if os.path.exists(yaml_path):
    with open(yaml_path, 'r') as f:
        data_config = yaml.safe_load(f)
    print(f"\nπ \nDataset configuration:")
    print(f"- Classes: {data_config.get('names', [])}")
    print(f"- Number of classes: {data_config.get('nc', 'N/A')}")
    print(f"- Train images: {data_config.get('train', 'N/A')}")
    print(f"- Val images: {data_config.get('val', 'N/A')}")
    print(f"- Test images: {data_config.get('test', 'N/A')}")
else:
    print(f"\nβ οΈ data.yaml not found at {yaml_path}")
π Total Images (Train + Valid + Test): 950 Count per split: - Train: 700 images (73.7%) - Valid: 155 images (16.3%) - Test: 95 images (10.0%) π Dataset configuration: - Classes: ['Santa'] - Number of classes: 1 - Train images: ../train/images - Val images: ../valid/images - Test images: ../test/images
# Usage:
# Same summary as the cell above, but via the shared helper from helpers.py;
# it prints split/class statistics and returns a dict of per-split counts.
analyze_dataset(dataset.location)
π Total: 950 | Train: 700 (73.7%) | Valid: 155 (16.3%) | Test: 95 (10.0%) π Classes: ['Santa'] | NC: 1
{'train': 700, 'valid': 155, 'test': 95}
# confirm the images and labels folders exist
!ls -F Santa-10/train/
images/ labels/
# Check images
!ls -F Santa-10/test/images | head -n 5
0-38862400_1671944344_santa_jpg.rf.55c6e9ce609b22e4451a311482828cd1.jpg 102_Santa_jpg.rf.8f184416dd01520cdc916a754e3001e6.jpg 146_Santa_jpg.rf.194951c92edbafce3d8144c398a907f5.jpg 156_Santa_jpg.rf.6ae01fa7adaefdd768712c55de29f1d5.jpg 174_Santa_jpg.rf.95ae299d0f0ef1d76006a4863c8e4b3e.jpg
# Check their annotations - same name, but the file ending is .txt
!ls -F Santa-10/test/labels/ | head -n 5
0-38862400_1671944344_santa_jpg.rf.55c6e9ce609b22e4451a311482828cd1.txt 102_Santa_jpg.rf.8f184416dd01520cdc916a754e3001e6.txt 146_Santa_jpg.rf.194951c92edbafce3d8144c398a907f5.txt 156_Santa_jpg.rf.6ae01fa7adaefdd768712c55de29f1d5.txt 174_Santa_jpg.rf.95ae299d0f0ef1d76006a4863c8e4b3e.txt
Visualize images with bounding boxes
# Root of the downloaded Roboflow export (YOLOv8 layout: <split>/images, <split>/labels).
DATASET_PATH = './Santa-10'
Bounding boxes
!pip install supervision -q
import cv2
import supervision as sv
import glob
import random
import os
import numpy as np
import matplotlib.pyplot as plt
# Maps YOLO class ids to human-readable names; this dataset has a single class.
CLASS_MAP = {0: 'Santa'}
def load_yolo_detections(label_path, W, H):
    """Read a YOLO-format label file and return sv.Detections in pixel coords.

    Args:
        label_path: Path to the .txt annotation file; one box per line as
            "class x_center y_center width height", all normalized to [0, 1].
        W: Width of the corresponding image in pixels.
        H: Height of the corresponding image in pixels.

    Returns:
        sv.Detections with xyxy pixel boxes and integer class ids, or None
        when the file is missing or contains no parseable boxes.
    """
    if not os.path.exists(label_path):
        return None

    boxes, class_ids = [], []
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 5:
                continue  # skip blank or malformed lines
            # FIX: only unpack the first five fields so optional trailing
            # values (e.g. a confidence column) no longer raise ValueError.
            c, x, y, w, h = map(float, parts[:5])
            # Convert normalized center/size to absolute corner coordinates.
            x1 = int((x - w / 2) * W)
            y1 = int((y - h / 2) * H)
            x2 = int((x + w / 2) * W)
            y2 = int((y + h / 2) * H)
            boxes.append([x1, y1, x2, y2])
            class_ids.append(int(c))

    if not boxes:
        return None
    return sv.Detections(xyxy=np.array(boxes), class_id=np.array(class_ids))
def visualize_images(image_paths, split_name):
    """Plot up to 10 random images from a split with their YOLO boxes drawn.

    Args:
        image_paths: List of image file paths belonging to the split.
        split_name: Split label used in the figure title ('train', 'valid', ...).

    Side effects: displays a 5x2 matplotlib grid; returns None.

    NOTE: shadows the `visualize_images` imported from helpers earlier in
    the notebook.
    """
    if not image_paths:
        print(f'No images found in {split_name}.')
        return

    sample = random.sample(image_paths, min(10, len(image_paths)))
    fig, axes = plt.subplots(5, 2, figsize=(16, 30))
    fig.suptitle(f"{split_name.capitalize()} Split - {len(sample)} Images", fontsize=16)
    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator()

    for i, path in enumerate(sample):
        img = cv2.imread(path)
        if img is None:
            continue  # unreadable/corrupt file — leave its subplot empty
        H, W, _ = img.shape
        # Label files mirror the images tree: images/x.jpg -> labels/x.txt
        label_path = path.replace('images', 'labels').rsplit('.', 1)[0] + '.txt'
        detections = load_yolo_detections(label_path, W, H)
        # Explicit None check: sv.Detections truthiness is length-based.
        if detections is not None:
            img = box_annotator.annotate(scene=img, detections=detections)
            labels = [CLASS_MAP.get(int(cid), f'Class {cid}') for cid in detections.class_id]
            img = label_annotator.annotate(scene=img, detections=detections, labels=labels)
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        image_id = os.path.splitext(os.path.basename(path))[0][:30]  # Truncate to 30 chars
        count = len(detections) if detections is not None else 0
        row, col = i // 2, i % 2
        axes[row, col].imshow(img_rgb)
        axes[row, col].set_title(f'{image_id} ({W}x{H}) - {count} boxes', fontsize=10)
        axes[row, col].axis('off')

    # FIX: hide the axes left unused when fewer than 10 images were sampled,
    # instead of showing empty framed subplots.
    for j in range(len(sample), 10):
        axes[j // 2, j % 2].axis('off')

    plt.tight_layout()
    plt.subplots_adjust(top=0.96)
    plt.show()
# Render annotated example images for every split in turn.
for split_name in ('train', 'test', 'valid'):
    candidates = glob.glob(f'{DATASET_PATH}/{split_name}/images/*')
    split_images = []
    for candidate in candidates:
        # Keep only the raster formats this export contains.
        if candidate.lower().endswith(('.jpg', '.png')):
            split_images.append(candidate)
    visualize_images(split_images, split_name)
    print('---')
Output hidden; open in https://colab.research.google.com to view.
TODO
- Comment on the images, if they are a good example for variety to train a robust model.
- Fix labels, where needed
Desired result: the bounding boxes are accurately placed and the associated class label (Santa) is correct. We checked for images with missing annotations (false negatives) and for mislabeled objects.
Annotation heatmap
Shows where most of the annotations are. Color gradients signify the number of annotations per grid cell.
TODO: Resize all images to 500x500 (with their bboxes) before building the heatmap — i.e. build the heatmap after plotting the distribution of image sizes in px and after visualizing image examples and Canny edges.
import os
import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle  # NOTE(review): unused in this cell

# Accumulate all train-split bounding boxes into a single density heatmap
# (each bbox, rescaled to the average image size, increments its area by 1).

# --- PATHS ---
# NOTE(review): this switches the analysis to the Santa-8 export while the
# cells above downloaded and inspected Santa-10 — confirm which dataset
# version the heatmap is meant to describe.
DATASET_PATH = 'Santa-8/train'
IMAGE_DIR = os.path.join(DATASET_PATH, 'images')
LABEL_DIR = os.path.join(DATASET_PATH, 'labels')

# --- LOAD ALL IMAGES AND LABELS ---
image_paths = glob.glob(os.path.join(IMAGE_DIR, '*.jpg'))
print(f"Found {len(image_paths)} images\n")

# --- CALCULATE AVERAGE IMAGE DIMENSIONS ---
heights, widths = [], []
for img_path in image_paths:
    img = cv2.imread(img_path)
    if img is not None:
        h, w, _ = img.shape
        heights.append(h)
        widths.append(w)

# The heatmap canvas uses the mean image size; every bbox is rescaled into it.
H = int(np.mean(heights))
W = int(np.mean(widths))
print(f"Average image dimensions: {W}x{H}")
print(f"Height range: {min(heights)} - {max(heights)}")
print(f"Width range: {min(widths)} - {max(widths)}\n")

# Create heatmap
heatmap = np.zeros((H, W), dtype=np.float32)

# --- ACCUMULATE BBOXES INTO HEATMAP ---
bbox_count = 0
for img_path in image_paths:
    img = cv2.imread(img_path)
    if img is None:
        continue
    img_h, img_w, _ = img.shape

    # Get corresponding label file
    label_path = img_path.replace('images', 'labels').rsplit('.', 1)[0] + '.txt'
    if not os.path.exists(label_path):
        continue

    # Read bboxes
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 5:
                continue
            # NOTE(review): would raise ValueError on a line with more than
            # five fields (e.g. a confidence column) — confirm label format.
            class_id, x, y, w, h = map(float, parts)
            # Convert normalized YOLO coords to pixel coords (for original image)
            x1 = int((x - w/2) * img_w)
            y1 = int((y - h/2) * img_h)
            x2 = int((x + w/2) * img_w)
            y2 = int((y + h/2) * img_h)
            # Scale bbox to average dimensions
            x1_scaled = int(x1 * W / img_w)
            y1_scaled = int(y1 * H / img_h)
            x2_scaled = int(x2 * W / img_w)
            y2_scaled = int(y2 * H / img_h)
            # Clip to heatmap bounds
            x1_scaled = max(0, x1_scaled)
            y1_scaled = max(0, y1_scaled)
            x2_scaled = min(W, x2_scaled)
            y2_scaled = min(H, y2_scaled)
            # Add bbox area to heatmap
            heatmap[y1_scaled:y2_scaled, x1_scaled:x2_scaled] += 1
            bbox_count += 1

print(f"Total bboxes: {bbox_count}")
print(f"Heatmap range: {heatmap.min():.1f} - {heatmap.max():.1f}\n")

# --- CALCULATE STATISTICS ---
# Quartiles of the per-pixel annotation counts across the whole canvas.
q1 = np.quantile(heatmap, 0.25)
median = np.quantile(heatmap, 0.50)
q3 = np.quantile(heatmap, 0.75)
min_val = heatmap.min()
max_val = heatmap.max()

stats_text = f"""# of Annotations Per Grid
Min: {int(min_val)}
Q1: {int(q1)}
Median: {int(median)}
Q3: {int(q3)}
Max: {int(max_val)}"""

# --- VISUALIZE HEATMAP ---
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
fig.suptitle(f'Annotation Density Heatmap - Train Split ({bbox_count} bboxes)', fontsize=14)

# Heatmap
im = axes[0].imshow(heatmap, cmap='hot', interpolation='bilinear')
# NOTE(review): with the 'hot' colormap *brighter* cells mean more
# annotations — the title below states the opposite; confirm intent.
axes[0].set_title('Bbox Density (darker = more annotations)', fontsize=12)
axes[0].set_xlabel('Width')
axes[0].set_ylabel('Height')
cbar = plt.colorbar(im, ax=axes[0], label='Count')

# Add text box with statistics
axes[0].text(0.02, 0.98, stats_text, transform=axes[0].transAxes,
             fontsize=10, verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

# Contour plot
axes[1].contourf(heatmap, levels=15, cmap='hot')
axes[1].set_title('Annotation Contours', fontsize=12)
axes[1].set_xlabel('Width')
axes[1].set_ylabel('Height')

plt.tight_layout()
plt.show()

# --- STATISTICS ---
print("\nπ Annotation Statistics:")
print(stats_text)
print(f"\nImage dimensions: {W}x{H}")
print(f"Mean annotations per pixel: {heatmap.mean():.4f}")
print(f"Coverage: {(heatmap > 0).sum() / (H * W) * 100:.1f}% of image")
Found 590 images Average image dimensions: 500x500 Height range: 500 - 500 Width range: 500 - 500 Total bboxes: 343 Heatmap range: 4.0 - 272.0
π Annotation Statistics: # of Annotations Per Grid Min: 4 Q1: 91 Median: 156 Q3: 220 Max: 272 Image dimensions: 500x500 Mean annotations per pixel: 153.7937 Coverage: 100.0% of image
Distribution of labels
# Per-image label distribution for the current split (DATASET_PATH).
IMAGE_EXT = '*.jpg'
LABEL_EXT = '*.txt'
TARGET_CLASS = 0  # class id for 'Santa' in this single-class dataset

# Pair every image with its expected label file (same stem, .txt extension).
image_paths = glob.glob(os.path.join(DATASET_PATH, 'images', IMAGE_EXT))
label_paths = [os.path.join(DATASET_PATH, 'labels', os.path.basename(p).rsplit('.', 1)[0] + '.txt')
               for p in image_paths]

total_images = len(image_paths)
santa_object_count = 0
images_with_santa = 0
images_without_santa = 0

# Iterate over all labels
for lbl_path in label_paths:
    if os.path.exists(lbl_path):
        # FIX: context manager so the file handle is always closed.
        with open(lbl_path, 'r') as f:
            lines = f.read().splitlines()
        class_ids = [int(line.split()[0]) for line in lines if line.strip()]
        # Use the TARGET_CLASS constant instead of a repeated magic 0.
        santa_count_in_image = class_ids.count(TARGET_CLASS)
        santa_object_count += santa_count_in_image
        if santa_count_in_image > 0:
            images_with_santa += 1
        else:
            images_without_santa += 1
    else:
        images_without_santa += 1  # No label file = background only

# Results
print(f'Total images: {total_images}')
print(f'Total Santa objects (positive examples): {santa_object_count}')
print(f'Images containing Santa: {images_with_santa}')
print(f'Images without Santa: {images_without_santa}')
print(f'Percentage of images with Santa: {images_with_santa / total_images * 100:.2f}%')
print(f'Percentage of background-only images: {images_without_santa / total_images * 100:.2f}%')
Total images: 590 Total Santa objects (positive examples): 366 Images containing Santa: 335 Images without Santa: 255 Percentage of images with Santa: 56.78% Percentage of background-only images: 43.22%
366 is the total number of bounding boxes labeled as Santa across the dataset; some images contain more than one Santa. 335 images contain at least one Santa.
# Identify images annotated with more than one bounding box.
multi_bbox_images = []
multi_bbox_labels = []

# Find images with more than one bounding box
for img_path, lbl_path in zip(image_paths, label_paths):
    if not os.path.exists(lbl_path):
        continue
    # FIX: context manager guarantees the label file is closed.
    with open(lbl_path, 'r') as f:
        lines = f.read().splitlines()
    bbox_count = sum(1 for line in lines if line.strip())
    if bbox_count > 1:
        multi_bbox_images.append(img_path)
        multi_bbox_labels.append(lbl_path)

print(f"Total images: {len(image_paths)}")
print(f"Images with >1 bounding box: {len(multi_bbox_images)}\n")
Total images: 590 Images with >1 bounding box: 17
# Visualize a random sample of multi-bbox images with their boxes drawn.
sample_size = min(10, len(multi_bbox_images))
sample_indices = np.random.choice(len(multi_bbox_images), sample_size, replace=False)

fig, axes = plt.subplots(5, 2, figsize=(14, 18))
fig.suptitle(f"Images with Multiple Bounding Boxes (showing {sample_size})", fontsize=16)

for idx, sample_idx in enumerate(sample_indices):
    img_path = multi_bbox_images[sample_idx]
    lbl_path = multi_bbox_labels[sample_idx]

    # Load image
    img = cv2.imread(img_path)
    if img is None:
        continue  # FIX: skip unreadable files instead of crashing in cvtColor
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    H, W, _ = img.shape

    # Read bounding boxes
    with open(lbl_path, 'r') as f:
        lines = f.read().splitlines()
    bbox_count = len([line for line in lines if line.strip()])

    # Draw bounding boxes
    for line in lines:
        if not line.strip():
            continue
        parts = line.split()
        class_id = int(parts[0])
        x, y, w, h = map(float, parts[1:5])
        # Convert normalized YOLO coords to pixel coords
        x1 = int((x - w/2) * W)
        y1 = int((y - h/2) * H)
        x2 = int((x + w/2) * W)
        y2 = int((y + h/2) * H)
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(img, f'Class {class_id}', (x1, y1-10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    # Plot
    row, col = idx // 2, idx % 2
    axes[row, col].imshow(img)
    axes[row, col].set_title(f"{os.path.basename(img_path)}\n({bbox_count} boxes)", fontsize=10)
    axes[row, col].axis('off')

# FIX: hide the axes of unused grid cells when fewer than 10 samples exist.
for j in range(sample_size, 10):
    axes[j // 2, j % 2].axis('off')

plt.tight_layout()
plt.show()
Image characteristics: edges, color scheme
# Do the images have clear edges?
# Pick three random Santa-only images from the split and compare each one
# against its Canny edge map.
all_train_images = glob.glob(os.path.join(DATASET_PATH, 'images', IMAGE_EXT))

# Keep only images whose every annotation is the Santa class.
santa_images = []
for img_path in all_train_images:
    # Label file lives in labels/ with the same stem and a .txt extension.
    label_path = img_path.replace('images', 'labels').rsplit('.', 1)[0] + '.txt'
    if not os.path.exists(label_path):
        continue
    with open(label_path, 'r') as f:
        labels = [int(line.split()[0]) for line in f.readlines()]
    # Non-empty label set where every class id matches TARGET_CLASS.
    if labels and all(label == TARGET_CLASS for label in labels):
        santa_images.append(img_path)

print(f"Found {len(santa_images)} images containing only Santa class.")

# --- Randomly select 3 images with Santa ---
sample_images = random.sample(santa_images, 3)

for img_path in sample_images:
    print(f"Image: {os.path.basename(img_path)}")
    bgr = cv2.imread(img_path)
    # Canny edge detection with fixed low/high thresholds.
    edge_map = cv2.Canny(bgr, 100, 200)
    # OpenCV loads BGR; convert for matplotlib display.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Side-by-side: original on the left, edge map on the right.
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(rgb)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    axes[1].imshow(edge_map, cmap='gray')
    axes[1].set_title('Canny Edges (100, 200)')
    axes[1].axis('off')
    plt.show()
Found 312 images containing only Santa class. Image: 588_Santa_jpg.rf.4db1fd8a1ffc7a45d5463ca951572759.jpg
Image: christmas-1903109_1280_jpg.rf.d2fe55a682d00dd853206b6c80efbfbb.jpg
Image: 384_Santa_jpg.rf.7bc450fb07597ade4ab2dae811d8d5d5.jpg
# Rebuild the Santa-only image list, then show five random examples next to
# their Canny edge maps in a single 5x2 grid.
all_train_images = glob.glob(os.path.join(DATASET_PATH, 'images', IMAGE_EXT))

santa_images = []
for img_path in all_train_images:
    label_path = img_path.replace('images', 'labels').rsplit('.', 1)[0] + '.txt'
    if not os.path.exists(label_path):
        continue
    with open(label_path, 'r') as f:
        labels = [int(line.split()[0]) for line in f.readlines()]
    if labels and all(label == TARGET_CLASS for label in labels):
        santa_images.append(img_path)

print(f"Found {len(santa_images)} images containing only Santa class.\n")

# Select 5 random images and apply Canny edge detection
sample_images = random.sample(santa_images, 5)

# One row per image: original on the left, Canny edge map on the right.
fig, axes = plt.subplots(5, 2, figsize=(8, 10))
fig.suptitle('Canny Edge Detection', fontsize=16)

for row, img_path in enumerate(sample_images):
    bgr = cv2.imread(img_path)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    edge_map = cv2.Canny(bgr, 100, 200)

    axes[row, 0].imshow(rgb)
    axes[row, 0].set_title(f'{os.path.basename(img_path)}', fontsize=10)
    axes[row, 0].axis('off')

    axes[row, 1].imshow(edge_map, cmap='gray')
    axes[row, 1].set_title(f'Canny (100, 200)', fontsize=10)
    axes[row, 1].axis('off')

plt.tight_layout()
plt.show()
Found 312 images containing only Santa class.
For Santa Claus images, we want thresholds that are high enough to capture the strong boundaries of the red suit, hat, and belt (the mid-level features) but low enough to capture the edges of the white beard and fur trim without picking up excessive background noise.
If objects have blurry, poor edges we can expect potential training issues, while high-contrast objects with strong edges are easier for the model to learn.
"""Per-image object-count statistics and class distribution for the train split."""
import glob
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt

# --- PATHS ---
IMAGE_DIR = os.path.join(DATASET_PATH, 'images')
LABEL_DIR = os.path.join(DATASET_PATH, 'labels')


def _label_path_for(img_path):
    """Return the YOLO label file for an image, swapping only the final
    'images' directory component (a plain str.replace would also rewrite
    any 'images' substring inside the filename)."""
    img_dir, img_name = os.path.split(img_path)
    label_dir = os.path.join(os.path.dirname(img_dir), 'labels')
    return os.path.join(label_dir, os.path.splitext(img_name)[0] + '.txt')


# --- COUNT OBJECTS PER IMAGE (single pass also tallies class frequencies,
# instead of re-reading every label file a second time for the pie chart) ---
image_paths = glob.glob(os.path.join(IMAGE_DIR, '*.jpg'))
object_counts = []
images_without_objects = 0
class_counter = Counter()  # class_id -> number of annotations

for img_path in image_paths:
    label_path = _label_path_for(img_path)
    count = 0
    if os.path.exists(label_path):
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.split()
                if parts:
                    count += 1
                    class_counter[int(parts[0])] += 1
    object_counts.append(count)
    if count == 0:
        images_without_objects += 1

object_counts = np.array(object_counts)
# Guard: .max()/.min() on an empty array raises a confusing ValueError.
if object_counts.size == 0:
    raise RuntimeError(f"No .jpg images found under {IMAGE_DIR}")

# --- STATISTICS ---
stats = {
    'Total Images': len(image_paths),
    'Images without objects': images_without_objects,
    'Images with objects': len(image_paths) - images_without_objects,
    'Total objects': int(object_counts.sum()),
    'Mean objects/image': f"{object_counts.mean():.2f}",
    'Median objects/image': int(np.median(object_counts)),
    'Min objects': int(object_counts.min()),
    'Max objects': int(object_counts.max()),
    'Std Dev': f"{object_counts.std():.2f}"
}
print("π Object Count Statistics:")
print("-" * 40)
for key, val in stats.items():
    print(f"{key:.<30} {val}")
print()

# --- PLOTS: histogram, box plot, cumulative distribution, class pie ---
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Object Distribution Analysis - Train Split', fontsize=16, fontweight='bold')

# 1. Histogram of object counts (red=mean, green=median)
ax = axes[0, 0]
bins = range(0, int(object_counts.max()) + 2)
ax.hist(object_counts, bins=bins, color='steelblue', edgecolor='black', alpha=0.7)
ax.axvline(object_counts.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {object_counts.mean():.2f}')
ax.axvline(np.median(object_counts), color='green', linestyle='--', linewidth=2, label=f'Median: {int(np.median(object_counts))}')
ax.set_xlabel('Objects per Image', fontsize=11)
ax.set_ylabel('Number of Images', fontsize=11)
ax.set_title('Histogram: Objects per Image', fontsize=12, fontweight='bold')
ax.legend()
ax.grid(alpha=0.3)

# 2. Box plot with quartile annotations
ax = axes[0, 1]
bp = ax.boxplot(object_counts, vert=True, patch_artist=True)
bp['boxes'][0].set_facecolor('lightblue')
ax.set_ylabel('Objects per Image', fontsize=11)
ax.set_title('Box Plot: Object Distribution', fontsize=12, fontweight='bold')
ax.grid(alpha=0.3, axis='y')
textstr = f"Q1: {int(np.quantile(object_counts, 0.25))}\nMedian: {int(np.median(object_counts))}\nQ3: {int(np.quantile(object_counts, 0.75))}"
ax.text(1.15, object_counts.mean(), textstr, fontsize=10,
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# 3. Cumulative distribution: what % of images have <= X objects
ax = axes[1, 0]
sorted_counts = np.sort(object_counts)
cumulative = np.arange(1, len(sorted_counts) + 1) / len(sorted_counts) * 100
ax.plot(sorted_counts, cumulative, marker='o', linewidth=2, markersize=4, color='steelblue')
ax.fill_between(sorted_counts, cumulative, alpha=0.3, color='steelblue')
ax.set_xlabel('Objects per Image', fontsize=11)
ax.set_ylabel('Cumulative % of Images', fontsize=11)
ax.set_title('Cumulative Distribution', fontsize=12, fontweight='bold')
ax.grid(alpha=0.3)

# 4. Class distribution pie chart. The Counter keeps the ORIGINAL class ids,
# fixing the previous version where filtering zero-count entries out of a
# fixed-size list relabelled the remaining classes 0..k-1.
ax = axes[1, 1]
present_classes = sorted(class_counter)
class_counts = [class_counter[c] for c in present_classes]
class_labels = [f'Class {c}' for c in present_classes]
if len(class_counts) > 1:
    colors = plt.cm.Set3(range(len(class_counts)))
    ax.pie(class_counts, labels=class_labels, autopct='%1.1f%%', colors=colors, startangle=90)
    ax.set_title('Class Distribution', fontsize=12, fontweight='bold')
else:
    ax.text(0.5, 0.5, f'Single Class Dataset\nTotal Objects: {sum(class_counts)}',
            ha='center', va='center', fontsize=14, transform=ax.transAxes)
    ax.set_title('Class Distribution', fontsize=12, fontweight='bold')
    ax.axis('off')

plt.tight_layout()
plt.show()

# --- DETAILED BREAKDOWN ---
print("\nπ Detailed Breakdown:")
print("-" * 40)
unique, counts = np.unique(object_counts, return_counts=True)
for obj_count, num_images in zip(unique, counts):
    pct = (num_images / len(image_paths)) * 100
    print(f"{int(obj_count)} object(s): {int(num_images):>3} images ({pct:>5.1f}%)")
π Object Count Statistics: ---------------------------------------- Total Images.................. 590 Images without objects........ 278 Images with objects........... 312 Total objects................. 343 Mean objects/image............ 0.58 Median objects/image.......... 1 Min objects................... 0 Max objects................... 6 Std Dev....................... 0.66
π Detailed Breakdown: ---------------------------------------- 0 object(s): 278 images ( 47.1%) 1 object(s): 295 images ( 50.0%) 2 object(s): 9 images ( 1.5%) 3 object(s): 5 images ( 0.8%) 4 object(s): 1 images ( 0.2%) 5 object(s): 1 images ( 0.2%) 6 object(s): 1 images ( 0.2%)
What we visualized:
- Histogram → Distribution of objects per image (red=mean, green=median)
- Box Plot → Quartiles, outliers, and spread
- Cumulative Distribution → What % of images have ≤X objects
- Class Distribution Pie → Breakdown by class (single-class or multi-class)
What to look for:
- Balanced → Similar counts across images (good for training)
- Skewed → Most images have 1 object, few have many (may need data augmentation)
- Empty images → Images with 0 objects (check if intentional negatives)
- Outliers → Images with unusually many objects (harder to train on)
%pwd
'/content/drive/My Drive/FHNW/HS_25/DLBS/minichallenge_hs25_object_detection'
"""
YOLO Dataset Exploratory Data Analysis Tool
Comprehensive visual analysis of YOLO object detection datasets
"""
import os
import yaml
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from pathlib import Path
import cv2
from collections import defaultdict
class YOLODatasetEDA:
"""
Comprehensive EDA for YOLO datasets with visualization.
"""
def __init__(self, yaml_path, dataset_root=None):
    """
    Initialize the EDA tool from a YOLO ``data.yaml`` file.

    Parameters
    ----------
    yaml_path : str
        Path to the data.yaml configuration file.
    dataset_root : str, optional
        Root directory of the dataset; when omitted, the directory that
        contains ``yaml_path`` is used (needed to resolve relative split
        paths from the config).
    """
    self.yaml_path = yaml_path

    # Parse the dataset configuration.
    with open(yaml_path, 'r') as cfg_file:
        self.config = yaml.safe_load(cfg_file)

    # Default the root to the yaml file's own directory.
    self.dataset_root = dataset_root if dataset_root is not None else os.path.dirname(yaml_path)

    # Class metadata straight from the config.
    self.class_names = self.config['names']
    self.num_classes = self.config['nc']

    print(f"π Dataset loaded: {self.num_classes} class(es)")
    print(f" Classes: {self.class_names}")
def _get_split_paths(self, split):
"""Get image and label paths for a split."""
# Handle relative paths
img_path = self.config[split]
if img_path.startswith('..'):
img_path = os.path.join(self.dataset_root, img_path.lstrip('../'))
# Get labels path (replace /images with /labels)
label_path = img_path.replace('/images', '/labels')
return img_path, label_path
def analyze_dataset_splits(self):
"""
Analyze dataset splits: train, val, test.
Returns statistics for each split.
"""
splits_data = {}
for split in ['train', 'val', 'test']:
if split not in self.config:
continue
img_path, label_path = self._get_split_paths(split)
if not os.path.exists(img_path):
print(f"β οΈ Warning: {split} images not found at {img_path}")
continue
# Count images
image_files = [f for f in os.listdir(img_path)
if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
num_images = len(image_files)
# Count annotations
total_annotations = 0
class_counts = defaultdict(int)
images_with_annotations = 0
if os.path.exists(label_path):
for img_file in image_files:
label_file = os.path.splitext(img_file)[0] + '.txt'
label_file_path = os.path.join(label_path, label_file)
if os.path.exists(label_file_path):
with open(label_file_path, 'r') as f:
lines = f.readlines()
if lines:
images_with_annotations += 1
for line in lines:
if line.strip():
parts = line.strip().split()
if parts:
cls_id = int(parts[0])
class_counts[cls_id] += 1
total_annotations += 1
splits_data[split] = {
'num_images': num_images,
'total_annotations': total_annotations,
'images_with_annotations': images_with_annotations,
'class_counts': dict(class_counts),
'avg_annotations_per_image': total_annotations / num_images if num_images > 0 else 0
}
print(f"\n{split.upper()} split:")
print(f" Images: {num_images}")
print(f" Total annotations: {total_annotations}")
print(f" Images with annotations: {images_with_annotations}")
print(f" Avg annotations/image: {splits_data[split]['avg_annotations_per_image']:.2f}")
return splits_data
def plot_split_statistics(self, splits_data):
    """
    Plot 1: two-panel figure — images per split (bar chart) and
    annotations per split stacked by class.

    Parameters
    ----------
    splits_data : dict
        Output of ``analyze_dataset_splits`` (split name -> stats dict
        with 'num_images' and 'class_counts' keys).
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    splits = list(splits_data.keys())
    # Plot 1: Number of images per split
    num_images = [splits_data[s]['num_images'] for s in splits]
    colors = ['#3498db', '#e74c3c', '#2ecc71']  # one fixed color per split
    bars = ax1.bar(splits, num_images, color=colors[:len(splits)],
                   edgecolor='black', alpha=0.7)
    ax1.set_ylabel('Number of Images', fontsize=12)
    ax1.set_title('Dataset Split Distribution', fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3, axis='y')
    # Add the image count above each bar
    for bar, count in zip(bars, num_images):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height,
                 f'{int(count)}',
                 ha='center', va='bottom', fontsize=12, fontweight='bold')
    # Plot 2: Annotations per split, stacked by class id.
    # class id -> list of per-split annotation counts (0 when absent)
    annotations_by_class = {}
    for cls_id in range(self.num_classes):
        annotations_by_class[cls_id] = [
            splits_data[s]['class_counts'].get(cls_id, 0) for s in splits
        ]
    bottom = np.zeros(len(splits))  # running stack height per split
    colors_classes = plt.cm.Set3(np.linspace(0, 1, self.num_classes))
    for cls_id in range(self.num_classes):
        # NOTE(review): both branches of this conditional index identically;
        # presumably it was meant to handle a dict-valued `names` — confirm.
        class_name = self.class_names[cls_id] if isinstance(self.class_names, list) else self.class_names[cls_id]
        counts = annotations_by_class[cls_id]
        ax2.bar(splits, counts, bottom=bottom, label=class_name,
                color=colors_classes[cls_id], edgecolor='black', alpha=0.8)
        # Write the count in the middle of each non-empty stacked segment
        for i, (split_name, count) in enumerate(zip(splits, counts)):
            if count > 0:
                ax2.text(i, bottom[i] + count/2, str(count),
                         ha='center', va='center', fontsize=10, fontweight='bold')
        bottom += counts
    ax2.set_ylabel('Number of Annotations', fontsize=12)
    ax2.set_title('Annotations per Split (by Class)', fontsize=14, fontweight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.show()
def analyze_image_dimensions(self):
"""
Analyze image dimensions across all splits.
"""
widths = []
heights = []
aspect_ratios = []
for split in ['train', 'val', 'test']:
if split not in self.config:
continue
img_path, _ = self._get_split_paths(split)
if not os.path.exists(img_path):
continue
image_files = [f for f in os.listdir(img_path)
if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
# Sample up to 100 images for speed
sampled_files = np.random.choice(image_files,
min(100, len(image_files)),
replace=False)
for img_file in sampled_files:
img_full_path = os.path.join(img_path, img_file)
try:
with Image.open(img_full_path) as img:
w, h = img.size
widths.append(w)
heights.append(h)
aspect_ratios.append(w / h)
except:
continue
return np.array(widths), np.array(heights), np.array(aspect_ratios)
def plot_image_dimensions(self, widths, heights, aspect_ratios):
    """
    Plot 2: image dimension analysis — width-vs-height scatter (colored
    by aspect ratio) plus width and height histograms; prints summary
    statistics afterwards.

    Parameters
    ----------
    widths, heights, aspect_ratios : np.ndarray
        Arrays as returned by ``analyze_image_dimensions`` (must be
        non-empty, since means/medians are computed).
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    # Scatter plot: width vs height, point color encodes aspect ratio
    ax = axes[0]
    scatter = ax.scatter(widths, heights, alpha=0.5, s=50, c=aspect_ratios,
                         cmap='viridis', edgecolors='black', linewidth=0.5)
    ax.set_xlabel('Width (pixels)', fontsize=12)
    ax.set_ylabel('Height (pixels)', fontsize=12)
    ax.set_title('Image Dimensions Distribution', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    # Cross-hair lines at the average width/height
    avg_w, avg_h = widths.mean(), heights.mean()
    ax.axvline(avg_w, color='red', linestyle='--', linewidth=2,
               label=f'Avg W: {avg_w:.0f}')
    ax.axhline(avg_h, color='blue', linestyle='--', linewidth=2,
               label=f'Avg H: {avg_h:.0f}')
    ax.legend()
    # Colorbar for the aspect-ratio encoding
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label('Aspect Ratio (W/H)', fontsize=10)
    # Histogram: widths, with mean/median markers
    ax = axes[1]
    ax.hist(widths, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
    ax.axvline(avg_w, color='red', linestyle='--', linewidth=2,
               label=f'Mean: {avg_w:.0f}')
    ax.axvline(np.median(widths), color='green', linestyle='--', linewidth=2,
               label=f'Median: {np.median(widths):.0f}')
    ax.set_xlabel('Width (pixels)', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title('Width Distribution', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    # Histogram: heights, with mean/median markers
    ax = axes[2]
    ax.hist(heights, bins=30, color='lightcoral', edgecolor='black', alpha=0.7)
    ax.axvline(avg_h, color='blue', linestyle='--', linewidth=2,
               label=f'Mean: {avg_h:.0f}')
    ax.axvline(np.median(heights), color='green', linestyle='--', linewidth=2,
               label=f'Median: {np.median(heights):.0f}')
    ax.set_xlabel('Height (pixels)', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title('Height Distribution', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.show()
    # Print summary statistics to stdout as well
    print(f"\nπ IMAGE DIMENSION STATISTICS:")
    print(f" Width: mean={widths.mean():.0f}, median={np.median(widths):.0f}, "
          f"std={widths.std():.0f}")
    print(f" Height: mean={heights.mean():.0f}, median={np.median(heights):.0f}, "
          f"std={heights.std():.0f}")
    print(f" Aspect Ratio: mean={aspect_ratios.mean():.2f}, "
          f"median={np.median(aspect_ratios):.2f}")
def create_annotation_heatmap(self, grid_size=20):
    """
    Plot 3: heatmap of annotation-center locations over the normalized
    image plane, accumulated across all configured splits.

    Colormap: blue = low, orange = intermediate, red = high density.

    Parameters
    ----------
    grid_size : int
        Resolution of the accumulation grid per axis.

    Returns
    -------
    np.ndarray
        (grid_size, grid_size) array of annotation-center counts.
    """
    # Fix: `mcolors` was referenced below without ever being imported,
    # which raised NameError at runtime. Import it locally here.
    import matplotlib.colors as mcolors

    # Accumulation grid over normalized coordinates (0-1 in x and y)
    heatmap = np.zeros((grid_size, grid_size))
    for split in ['train', 'val', 'test']:
        if split not in self.config:
            continue
        img_path, label_path = self._get_split_paths(split)
        if not os.path.exists(label_path):
            continue
        label_files = [f for f in os.listdir(label_path) if f.endswith('.txt')]
        for label_file in label_files:
            label_file_path = os.path.join(label_path, label_file)
            with open(label_file_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 5:
                        # YOLO format: class x_center y_center width height (normalized)
                        x_center = float(parts[1])
                        y_center = float(parts[2])
                        # Map the normalized center to a grid cell, clamped
                        # so that x/y == 1.0 lands in the last cell.
                        grid_x = max(0, min(grid_size - 1, int(x_center * grid_size)))
                        grid_y = max(0, min(grid_size - 1, int(y_center * grid_size)))
                        heatmap[grid_y, grid_x] += 1

    # Custom colormap: blue (low) -> orange (mid) -> red (high)
    cmap = mcolors.LinearSegmentedColormap.from_list(
        "blue_red", ["blue", "orange", "red"], N=256
    )
    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(heatmap, cmap=cmap, interpolation='bilinear', origin='upper')
    ax.set_xlabel('Normalized X Position', fontsize=12)
    ax.set_ylabel('Normalized Y Position', fontsize=12)
    ax.set_title('Annotation Center Heatmap (All Splits)',
                 fontsize=14, fontweight='bold')
    # Colorbar for the density scale
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Annotation Density', fontsize=12)
    # Tick labels in normalized (0-1) coordinates
    ax.set_xticks(np.arange(0, grid_size, grid_size // 5))
    ax.set_yticks(np.arange(0, grid_size, grid_size // 5))
    ax.set_xticklabels([f'{x / grid_size:.1f}' for x in range(0, grid_size, grid_size // 5)])
    ax.set_yticklabels([f'{y / grid_size:.1f}' for y in range(0, grid_size, grid_size // 5)])
    ax.grid(True, alpha=0.3, color='white', linewidth=1)
    plt.tight_layout()
    plt.show()
    return heatmap
def analyze_objects_per_image(self):
"""
Analyze distribution of object counts per image.
"""
object_counts = []
for split in ['train', 'val', 'test']:
if split not in self.config:
continue
img_path, label_path = self._get_split_paths(split)
if not os.path.exists(img_path) or not os.path.exists(label_path):
continue
image_files = [f for f in os.listdir(img_path)
if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
for img_file in image_files:
label_file = os.path.splitext(img_file)[0] + '.txt'
label_file_path = os.path.join(label_path, label_file)
count = 0
if os.path.exists(label_file_path):
with open(label_file_path, 'r') as f:
count = len([line for line in f if line.strip()])
object_counts.append(count)
return np.array(object_counts)
def plot_objects_per_image(self, object_counts):
    """
    Plot 4: histogram of object count per image, with mean/median
    markers; prints summary statistics afterwards.

    Parameters
    ----------
    object_counts : np.ndarray
        Per-image object counts from ``analyze_objects_per_image``
        (must be non-empty, since .max()/.mean() are computed).
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    # One integer bin per object count, plus one trailing bin
    max_count = object_counts.max()
    bins = range(0, max_count + 2)
    ax.hist(object_counts, bins=bins, edgecolor='black',
            color='mediumseagreen', alpha=0.7)
    # Statistics markers
    mean_count = object_counts.mean()
    median_count = np.median(object_counts)
    ax.axvline(mean_count, color='red', linestyle='--', linewidth=2,
               label=f'Mean: {mean_count:.2f}')
    ax.axvline(median_count, color='blue', linestyle='--', linewidth=2,
               label=f'Median: {median_count:.0f}')
    ax.set_xlabel('Number of Objects per Image', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title('Distribution of Object Count per Image',
                 fontsize=14, fontweight='bold')
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.show()
    # Print the same statistics to stdout
    print(f"\nπ¦ OBJECTS PER IMAGE STATISTICS:")
    print(f" Mean: {mean_count:.2f}")
    print(f" Median: {median_count:.0f}")
    print(f" Max: {max_count}")
    print(f" Images with 0 objects: {np.sum(object_counts == 0)}")
    print(f" Images with 1 object: {np.sum(object_counts == 1)}")
    print(f" Images with 2+ objects: {np.sum(object_counts >= 2)}")
def visualize_sample_images(self, num_samples=6, split='train'):
    """
    Plot 5: random sample images from a split with their YOLO bounding
    boxes and class names drawn on top.

    Parameters
    ----------
    num_samples : int
        Number of images to sample (at most 6 are shown — the grid is
        fixed at 2x3).
    split : str
        Which split to sample from ('train'/'val'/'test').
    """
    img_path, label_path = self._get_split_paths(split)
    if not os.path.exists(img_path):
        print(f"β οΈ {split} images not found")
        return
    image_files = [f for f in os.listdir(img_path)
                   if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    # Sample random images without replacement
    sampled_files = np.random.choice(image_files,
                                     min(num_samples, len(image_files)),
                                     replace=False)
    # Fixed 2x3 display grid
    rows = 2
    cols = 3
    fig, axes = plt.subplots(rows, cols, figsize=(18, 12))
    axes = axes.flatten()
    for idx, img_file in enumerate(sampled_files):
        if idx >= rows * cols:
            break
        # Load image (BGR) and convert for matplotlib display.
        # NOTE(review): cv2.imread returns None for unreadable files,
        # which would make cvtColor raise — confirm inputs are valid.
        img_full_path = os.path.join(img_path, img_file)
        img = cv2.imread(img_full_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        # Load annotations for this image, if any
        label_file = os.path.splitext(img_file)[0] + '.txt'
        label_file_path = os.path.join(label_path, label_file)
        if os.path.exists(label_file_path):
            with open(label_file_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 5:
                        # YOLO format: class x_center y_center width height
                        # (all normalized) — scale up to pixel coordinates
                        cls_id = int(parts[0])
                        x_center = float(parts[1]) * w
                        y_center = float(parts[2]) * h
                        box_w = float(parts[3]) * w
                        box_h = float(parts[4]) * h
                        # Convert center/size to corner coordinates
                        x1 = int(x_center - box_w / 2)
                        y1 = int(y_center - box_h / 2)
                        x2 = int(x_center + box_w / 2)
                        y2 = int(y_center + box_h / 2)
                        # Draw the box (red in the RGB image)
                        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
                        # Draw the class name just above the box
                        class_name = self.class_names[cls_id] if isinstance(self.class_names, list) else self.class_names[cls_id]
                        cv2.putText(img, class_name, (x1, y1 - 10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
        # Display the annotated image
        axes[idx].imshow(img)
        axes[idx].set_title(f'{split}: {img_file}', fontsize=10)
        axes[idx].axis('off')
    # Hide unused subplots when fewer than rows*cols images were sampled
    for idx in range(len(sampled_files), rows * cols):
        axes[idx].axis('off')
    plt.suptitle(f'Sample Images with Annotations ({split.upper()} split)',
                 fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()
def run_full_eda(self):
    """
    Run the complete EDA pipeline: split statistics, annotation heatmap,
    objects-per-image histogram, and sample visualizations for the train
    and val splits. (The image-dimension analysis is intentionally
    disabled below.)
    """
    print("="*70)
    print("π YOLO DATASET EXPLORATORY DATA ANALYSIS")
    print("="*70)

    # 1. Dataset splits analysis
    print("\nπ Analyzing dataset splits...")
    splits_data = self.analyze_dataset_splits()
    self.plot_split_statistics(splits_data)

    # # 2. Image dimensions (disabled)
    # print("\nπ Analyzing image dimensions...")
    # widths, heights, aspect_ratios = self.analyze_image_dimensions()
    # self.plot_image_dimensions(widths, heights, aspect_ratios)

    # 3. Annotation heatmap
    print("\nπΊοΈ Creating annotation heatmap...")
    self.create_annotation_heatmap(grid_size=20)

    # 4. Objects per image
    print("\nπ¦ Analyzing objects per image...")
    object_counts = self.analyze_objects_per_image()
    self.plot_objects_per_image(object_counts)

    # 5. Sample images with bounding boxes, per available split
    print("\nπΌοΈ Visualizing sample images...")
    for split in ['train', 'val']:
        if split in self.config:
            print(f"\n {split.upper()} samples:")
            self.visualize_sample_images(num_samples=6, split=split)

    print("\n" + "="*70)
    # Fix: the completion banner was a string literal broken across two
    # source lines (mangled emoji), which is a SyntaxError; restored as a
    # single line.
    print("β EDA COMPLETE!")
    print("="*70)
# =======================
# USAGE
# =======================
# Initialize EDA: point these at the dataset's data.yaml and root folder.
yaml_path = '/content/drive/My Drive/FHNW/HS_25/DLBS/minichallenge_hs25_object_detection/Santa-9/data.yaml' # Update this path
dataset_root = '/content/drive/My Drive/FHNW/HS_25/DLBS/minichallenge_hs25_object_detection/Santa-9' # Update if needed
eda = YOLODatasetEDA(yaml_path, dataset_root)
# Run full analysis (splits, heatmap, objects/image, sample images)
eda.run_full_eda()
# Or run individual analyses
# splits_data = eda.analyze_dataset_splits()
# eda.plot_split_statistics(splits_data)
# widths, heights, aspect_ratios = eda.analyze_image_dimensions()
# eda.plot_image_dimensions(widths, heights, aspect_ratios)
Output hidden; open in https://colab.research.google.com to view.
"""
YOLO Dataset Exploratory Data Analysis Tool
Comprehensive visual analysis of YOLO object detection datasets
"""
import os
import yaml
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from pathlib import Path
import cv2
from collections import defaultdict
class YOLODatasetEDA:
"""
Comprehensive EDA for YOLO datasets with visualization.
"""
def __init__(self, yaml_path, dataset_root=None):
"""
Initialize EDA tool.
Parameters:
-----------
yaml_path : str
Path to data.yaml file
dataset_root : str, optional
Root directory of dataset (if yaml paths are relative)
"""
self.yaml_path = yaml_path
# Load YAML
with open(yaml_path, 'r') as f:
self.config = yaml.safe_load(f)
# Set dataset root
if dataset_root is None:
dataset_root = os.path.dirname(yaml_path)
self.dataset_root = dataset_root
# Get class names
self.class_names = self.config['names']
self.num_classes = self.config['nc']
print(f"π Dataset loaded: {self.num_classes} class(es)")
print(f" Classes: {self.class_names}")
def _get_split_paths(self, split):
"""Get image and label paths for a split."""
# Handle relative paths
img_path = self.config[split]
if img_path.startswith('..'):
img_path = os.path.join(self.dataset_root, img_path.lstrip('../'))
# Get labels path (replace /images with /labels)
label_path = img_path.replace('/images', '/labels')
return img_path, label_path
def analyze_dataset_splits(self):
"""
Analyze dataset splits: train, val, test.
Returns statistics for each split.
"""
splits_data = {}
for split in ['train', 'val', 'test']:
if split not in self.config:
continue
img_path, label_path = self._get_split_paths(split)
if not os.path.exists(img_path):
print(f"β οΈ Warning: {split} images not found at {img_path}")
continue
# Count images
image_files = [f for f in os.listdir(img_path)
if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
num_images = len(image_files)
# Count annotations
total_annotations = 0
class_counts = defaultdict(int)
images_with_annotations = 0
if os.path.exists(label_path):
for img_file in image_files:
label_file = os.path.splitext(img_file)[0] + '.txt'
label_file_path = os.path.join(label_path, label_file)
if os.path.exists(label_file_path):
with open(label_file_path, 'r') as f:
lines = f.readlines()
if lines:
images_with_annotations += 1
for line in lines:
if line.strip():
parts = line.strip().split()
if parts:
cls_id = int(parts[0])
class_counts[cls_id] += 1
total_annotations += 1
splits_data[split] = {
'num_images': num_images,
'total_annotations': total_annotations,
'images_with_annotations': images_with_annotations,
'class_counts': dict(class_counts),
'avg_annotations_per_image': total_annotations / num_images if num_images > 0 else 0
}
print(f"\n{split.upper()} split:")
print(f" Images: {num_images}")
print(f" Total annotations: {total_annotations}")
print(f" Images with annotations: {images_with_annotations}")
print(f" Avg annotations/image: {splits_data[split]['avg_annotations_per_image']:.2f}")
return splits_data
def plot_split_statistics(self, splits_data):
"""
Plot 1: Dataset split bar chart with class annotations.
"""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
splits = list(splits_data.keys())
# Plot 1: Number of images per split
num_images = [splits_data[s]['num_images'] for s in splits]
colors = ['#3498db', '#e74c3c', '#2ecc71']
bars = ax1.bar(splits, num_images, color=colors[:len(splits)],
edgecolor='black', alpha=0.7)
ax1.set_ylabel('Number of Images', fontsize=12)
ax1.set_title('Dataset Split Distribution', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3, axis='y')
# Add value labels on bars
for bar, count in zip(bars, num_images):
height = bar.get_height()
ax1.text(bar.get_x() + bar.get_width()/2., height,
f'{int(count)}',
ha='center', va='bottom', fontsize=12, fontweight='bold')
# Plot 2: Annotations per split (stacked by class)
annotations_by_class = {}
for cls_id in range(self.num_classes):
annotations_by_class[cls_id] = [
splits_data[s]['class_counts'].get(cls_id, 0) for s in splits
]
bottom = np.zeros(len(splits))
colors_classes = plt.cm.Set3(np.linspace(0, 1, self.num_classes))
for cls_id in range(self.num_classes):
class_name = self.class_names[cls_id] if isinstance(self.class_names, list) else self.class_names[cls_id]
counts = annotations_by_class[cls_id]
ax2.bar(splits, counts, bottom=bottom, label=class_name,
color=colors_classes[cls_id], edgecolor='black', alpha=0.8)
# Add value labels
for i, (split_name, count) in enumerate(zip(splits, counts)):
if count > 0:
ax2.text(i, bottom[i] + count/2, str(count),
ha='center', va='center', fontsize=10, fontweight='bold')
bottom += counts
ax2.set_ylabel('Number of Annotations', fontsize=12)
ax2.set_title('Annotations per Split (by Class)', fontsize=14, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()
def analyze_image_dimensions(self):
"""
Analyze image dimensions across all splits.
"""
widths = []
heights = []
aspect_ratios = []
for split in ['train', 'val', 'test']:
if split not in self.config:
continue
img_path, _ = self._get_split_paths(split)
if not os.path.exists(img_path):
continue
image_files = [f for f in os.listdir(img_path)
if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
# Sample up to 100 images for speed
sampled_files = np.random.choice(image_files,
min(100, len(image_files)),
replace=False)
for img_file in sampled_files:
img_full_path = os.path.join(img_path, img_file)
try:
with Image.open(img_full_path) as img:
w, h = img.size
widths.append(w)
heights.append(h)
aspect_ratios.append(w / h)
except:
continue
return np.array(widths), np.array(heights), np.array(aspect_ratios)
def plot_image_dimensions(self, widths, heights, aspect_ratios):
"""
Plot 2: Image dimensions and aspect ratios.
"""
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# Scatter plot: width vs height
ax = axes[0]
scatter = ax.scatter(widths, heights, alpha=0.5, s=50, c=aspect_ratios,
cmap='viridis', edgecolors='black', linewidth=0.5)
ax.set_xlabel('Width (pixels)', fontsize=12)
ax.set_ylabel('Height (pixels)', fontsize=12)
ax.set_title('Image Dimensions Distribution', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
# Add average lines
avg_w, avg_h = widths.mean(), heights.mean()
ax.axvline(avg_w, color='red', linestyle='--', linewidth=2,
label=f'Avg W: {avg_w:.0f}')
ax.axhline(avg_h, color='blue', linestyle='--', linewidth=2,
label=f'Avg H: {avg_h:.0f}')
ax.legend()
# Colorbar
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('Aspect Ratio (W/H)', fontsize=10)
# Histogram: widths
ax = axes[1]
ax.hist(widths, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
ax.axvline(avg_w, color='red', linestyle='--', linewidth=2,
label=f'Mean: {avg_w:.0f}')
ax.axvline(np.median(widths), color='green', linestyle='--', linewidth=2,
label=f'Median: {np.median(widths):.0f}')
ax.set_xlabel('Width (pixels)', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('Width Distribution', fontsize=14, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')
# Histogram: heights
ax = axes[2]
ax.hist(heights, bins=30, color='lightcoral', edgecolor='black', alpha=0.7)
ax.axvline(avg_h, color='blue', linestyle='--', linewidth=2,
label=f'Mean: {avg_h:.0f}')
ax.axvline(np.median(heights), color='green', linestyle='--', linewidth=2,
label=f'Median: {np.median(heights):.0f}')
ax.set_xlabel('Height (pixels)', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('Height Distribution', fontsize=14, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()
# Print statistics
print(f"\nπ IMAGE DIMENSION STATISTICS:")
print(f" Width: mean={widths.mean():.0f}, median={np.median(widths):.0f}, "
f"std={widths.std():.0f}")
print(f" Height: mean={heights.mean():.0f}, median={np.median(heights):.0f}, "
f"std={heights.std():.0f}")
print(f" Aspect Ratio: mean={aspect_ratios.mean():.2f}, "
f"median={np.median(aspect_ratios):.2f}")
def create_annotation_heatmap(self, grid_size=20):
"""
Plot 3: Heatmap showing where annotations are located in images.
"""
# Create grid for heatmap (normalized coordinates 0-1)
heatmap = np.zeros((grid_size, grid_size))
for split in ['train', 'val', 'test']:
if split not in self.config:
continue
img_path, label_path = self._get_split_paths(split)
if not os.path.exists(label_path):
continue
label_files = [f for f in os.listdir(label_path) if f.endswith('.txt')]
for label_file in label_files:
label_file_path = os.path.join(label_path, label_file)
with open(label_file_path, 'r') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 5:
# YOLO format: class x_center y_center width height (normalized)
x_center = float(parts[1])
y_center = float(parts[2])
# Map to grid
grid_x = int(x_center * grid_size)
grid_y = int(y_center * grid_size)
# Clamp to valid range
grid_x = max(0, min(grid_size - 1, grid_x))
grid_y = max(0, min(grid_size - 1, grid_y))
heatmap[grid_y, grid_x] += 1
# Plot heatmap
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(heatmap, cmap='hot', interpolation='bilinear', origin='upper')
ax.set_xlabel('Normalized X Position', fontsize=12)
ax.set_ylabel('Normalized Y Position', fontsize=12)
ax.set_title('Annotation Center Heatmap (All Splits)',
fontsize=14, fontweight='bold')
# Add colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Annotation Density', fontsize=12)
# Add grid
ax.set_xticks(np.arange(0, grid_size, grid_size//5))
ax.set_yticks(np.arange(0, grid_size, grid_size//5))
ax.set_xticklabels([f'{x/grid_size:.1f}' for x in range(0, grid_size, grid_size//5)])
ax.set_yticklabels([f'{y/grid_size:.1f}' for y in range(0, grid_size, grid_size//5)])
ax.grid(True, alpha=0.3, color='white', linewidth=1)
plt.tight_layout()
plt.show()
return heatmap
def analyze_objects_per_image(self):
"""
Analyze distribution of object counts per image.
"""
object_counts = []
for split in ['train', 'val', 'test']:
if split not in self.config:
continue
img_path, label_path = self._get_split_paths(split)
if not os.path.exists(img_path) or not os.path.exists(label_path):
continue
image_files = [f for f in os.listdir(img_path)
if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
for img_file in image_files:
label_file = os.path.splitext(img_file)[0] + '.txt'
label_file_path = os.path.join(label_path, label_file)
count = 0
if os.path.exists(label_file_path):
with open(label_file_path, 'r') as f:
count = len([line for line in f if line.strip()])
object_counts.append(count)
return np.array(object_counts)
def plot_objects_per_image(self, object_counts):
    """
    Plot 4: Histogram of object count per image, with mean/median markers.

    Parameters
    ----------
    object_counts : np.ndarray
        Per-image object counts, e.g. from ``analyze_objects_per_image``.

    Prints summary statistics after showing the plot. Returns early (with a
    warning) on empty input, since ``.max()`` on an empty array raises.
    """
    # Guard: the original crashed with ValueError on an empty array
    if len(object_counts) == 0:
        print("β οΈ No images found - nothing to plot")
        return
    fig, ax = plt.subplots(figsize=(12, 6))
    max_count = int(object_counts.max())
    # One bin per integer count, inclusive of max_count
    bins = range(0, max_count + 2)
    ax.hist(object_counts, bins=bins, edgecolor='black',
            color='mediumseagreen', alpha=0.7)
    # Statistics
    mean_count = object_counts.mean()
    median_count = np.median(object_counts)
    ax.axvline(mean_count, color='red', linestyle='--', linewidth=2,
               label=f'Mean: {mean_count:.2f}')
    ax.axvline(median_count, color='blue', linestyle='--', linewidth=2,
               label=f'Median: {median_count:.0f}')
    ax.set_xlabel('Number of Objects per Image', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title('Distribution of Object Count per Image',
                 fontsize=14, fontweight='bold')
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.show()
    # Print statistics
    print(f"\nπ¦ OBJECTS PER IMAGE STATISTICS:")
    print(f" Mean: {mean_count:.2f}")
    print(f" Median: {median_count:.0f}")
    print(f" Max: {max_count}")
    print(f" Images with 0 objects: {np.sum(object_counts == 0)}")
    print(f" Images with 1 object: {np.sum(object_counts == 1)}")
    print(f" Images with 2+ objects: {np.sum(object_counts >= 2)}")
def visualize_sample_images(self, num_samples=6, split='train'):
    """
    Plot 5: random sample images with their YOLO bounding boxes drawn.

    Parameters
    ----------
    num_samples : int
        Number of images to sample. The grid adapts (3 columns, as many
        rows as needed); the default of 6 keeps the original 2x3 layout.
    split : str
        Dataset split to sample from ('train', 'val' or 'test').
    """
    img_path, label_path = self._get_split_paths(split)
    if not os.path.exists(img_path):
        print(f"β οΈ {split} images not found")
        return
    image_files = [f for f in os.listdir(img_path)
                   if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    if not image_files:
        # np.random.choice raises on an empty population
        print(f"β οΈ {split} split contains no images")
        return
    # Sample random images without replacement
    sampled_files = np.random.choice(image_files,
                                     min(num_samples, len(image_files)),
                                     replace=False)
    # Grid adapts to the sample size instead of a hard-coded 2x3
    cols = 3
    rows = max(1, int(np.ceil(len(sampled_files) / cols)))
    fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 6 * rows),
                             squeeze=False)
    axes = axes.flatten()
    for idx, img_file in enumerate(sampled_files):
        # Load image (cv2.imread returns None for unreadable files)
        img_full_path = os.path.join(img_path, img_file)
        img = cv2.imread(img_full_path)
        if img is None:
            print(f"β οΈ Could not read {img_full_path}")
            axes[idx].axis('off')
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        # Load annotations (YOLO txt: class x_center y_center width height)
        label_file = os.path.splitext(img_file)[0] + '.txt'
        label_file_path = os.path.join(label_path, label_file)
        if os.path.exists(label_file_path):
            with open(label_file_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) < 5:
                        continue
                    cls_id = int(parts[0])
                    # Normalized center/size -> pixel corner coordinates
                    x_center = float(parts[1]) * w
                    y_center = float(parts[2]) * h
                    box_w = float(parts[3]) * w
                    box_h = float(parts[4]) * h
                    x1 = int(x_center - box_w / 2)
                    y1 = int(y_center - box_h / 2)
                    x2 = int(x_center + box_w / 2)
                    y2 = int(y_center + box_h / 2)
                    cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
                    # Indexing works for both list and dict name containers
                    # (the original ternary had two identical branches)
                    class_name = self.class_names[cls_id]
                    cv2.putText(img, class_name, (x1, y1 - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
        axes[idx].imshow(img)
        axes[idx].set_title(f'{split}: {img_file}', fontsize=10)
        axes[idx].axis('off')
    # Hide unused subplots
    for idx in range(len(sampled_files), rows * cols):
        axes[idx].axis('off')
    plt.suptitle(f'Sample Images with Annotations ({split.upper()} split)',
                 fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()
def run_full_eda(self):
    """
    Run the complete EDA pipeline: split statistics, image dimensions,
    annotation heatmap, objects-per-image histogram and annotated samples.
    """
    print("="*70)
    print("π YOLO DATASET EXPLORATORY DATA ANALYSIS")
    print("="*70)
    # 1. Dataset splits analysis
    print("\nπ Analyzing dataset splits...")
    splits_data = self.analyze_dataset_splits()
    self.plot_split_statistics(splits_data)
    # 2. Image dimensions
    print("\nπ Analyzing image dimensions...")
    widths, heights, aspect_ratios = self.analyze_image_dimensions()
    self.plot_image_dimensions(widths, heights, aspect_ratios)
    # 3. Annotation heatmap
    print("\nπΊοΈ Creating annotation heatmap...")
    self.create_annotation_heatmap(grid_size=20)
    # 4. Objects per image
    print("\nπ¦ Analyzing objects per image...")
    object_counts = self.analyze_objects_per_image()
    self.plot_objects_per_image(object_counts)
    # 5. Sample images
    print("\nπΌοΈ Visualizing sample images...")
    for split in ['train', 'val']:
        if split in self.config:
            print(f"\n {split.upper()} samples:")
            self.visualize_sample_images(num_samples=6, split=split)
    print("\n" + "="*70)
    # Fix: the original string literal was split across two physical lines
    # (broken emoji in the notebook export), which is a SyntaxError.
    print("β EDA COMPLETE!")
    print("="*70)
# =======================
# USAGE
# =======================
# Initialize EDA
# NOTE(review): the next two lines are no-op self-assignments; they rely on
# `yaml_path` and `dataset_root` already being defined in an earlier notebook
# cell, and raise NameError otherwise. Replace them with explicit paths when
# running this cell standalone.
yaml_path = yaml_path # Update this path
dataset_root = dataset_root # Update if needed
eda = YOLODatasetEDA(yaml_path, dataset_root)
# Run full analysis
eda.run_full_eda()
# Or run individual analyses
# splits_data = eda.analyze_dataset_splits()
# eda.plot_split_statistics(splits_data)
# widths, heights, aspect_ratios = eda.analyze_image_dimensions()
# eda.plot_image_dimensions(widths, heights, aspect_ratios)
Output hidden; open in https://colab.research.google.com to view.
"""
YOLO Dataset Exploratory Data Analysis Tool
Comprehensive visual analysis of YOLO object detection datasets
"""
import os
import yaml
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from pathlib import Path
import cv2
from collections import defaultdict
class YOLODatasetEDA:
    """
    Comprehensive EDA for YOLO-format object-detection datasets.

    Provides split statistics, image-dimension analysis, annotation-center
    heatmaps, objects-per-image histograms and annotated sample plots.
    Expects the standard Roboflow/YOLO layout where each split has sibling
    ``images/`` and ``labels/`` directories.
    """

    def __init__(self, yaml_path, dataset_root=None):
        """
        Initialize EDA tool.

        Parameters
        ----------
        yaml_path : str
            Path to data.yaml file.
        dataset_root : str, optional
            Root directory of dataset (if yaml paths are relative).
            Defaults to the directory containing ``yaml_path``.
        """
        self.yaml_path = yaml_path
        # Load YAML config (split paths, class names, class count)
        with open(yaml_path, 'r') as f:
            self.config = yaml.safe_load(f)
        # Relative split paths from the YAML are resolved against this root
        if dataset_root is None:
            dataset_root = os.path.dirname(yaml_path)
        self.dataset_root = dataset_root
        # `names` may be a list or an int-keyed dict; both index by class id
        self.class_names = self.config['names']
        self.num_classes = self.config['nc']
        print(f"π Dataset loaded: {self.num_classes} class(es)")
        print(f" Classes: {self.class_names}")

    def _get_split_paths(self, split):
        """Return ``(images_dir, labels_dir)`` for a split.

        Resolves leading ``../`` segments against ``dataset_root`` and
        derives the labels directory by swapping ``/images`` for ``/labels``.
        """
        img_path = self.config[split]
        if img_path.startswith('..'):
            # Strip leading '../' segments one at a time. The original used
            # str.lstrip('../'), which strips a *character set* and would
            # also eat '.' and '/' from the first real path component.
            rel = img_path
            while rel.startswith('../'):
                rel = rel[3:]
            img_path = os.path.join(self.dataset_root, rel)
        # Standard YOLO layout: labels live next to images
        label_path = img_path.replace('/images', '/labels')
        return img_path, label_path

    def analyze_dataset_splits(self):
        """
        Analyze the train/val/test splits.

        Returns
        -------
        dict
            Per-split statistics: image count, total annotations, number of
            images with at least one annotation, per-class counts and the
            average annotation count per image.
        """
        splits_data = {}
        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue
            img_path, label_path = self._get_split_paths(split)
            if not os.path.exists(img_path):
                print(f"β οΈ Warning: {split} images not found at {img_path}")
                continue
            # Count images
            image_files = [f for f in os.listdir(img_path)
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            num_images = len(image_files)
            # Count annotations
            total_annotations = 0
            class_counts = defaultdict(int)
            images_with_annotations = 0
            if os.path.exists(label_path):
                for img_file in image_files:
                    label_file = os.path.splitext(img_file)[0] + '.txt'
                    label_file_path = os.path.join(label_path, label_file)
                    if not os.path.exists(label_file_path):
                        continue
                    with open(label_file_path, 'r') as f:
                        # Only non-blank lines are annotations; the original
                        # counted a whitespace-only label file as annotated.
                        lines = [line.strip() for line in f if line.strip()]
                    if lines:
                        images_with_annotations += 1
                    for line in lines:
                        cls_id = int(line.split()[0])
                        class_counts[cls_id] += 1
                        total_annotations += 1
            splits_data[split] = {
                'num_images': num_images,
                'total_annotations': total_annotations,
                'images_with_annotations': images_with_annotations,
                'class_counts': dict(class_counts),
                'avg_annotations_per_image':
                    total_annotations / num_images if num_images > 0 else 0
            }
            print(f"\n{split.upper()} split:")
            print(f" Images: {num_images}")
            print(f" Total annotations: {total_annotations}")
            print(f" Images with annotations: {images_with_annotations}")
            print(f" Avg annotations/image: {splits_data[split]['avg_annotations_per_image']:.2f}")
        return splits_data

    def plot_split_statistics(self, splits_data):
        """
        Plot 1: stacked bar chart of images with vs. without Santa per split,
        each segment labeled with its absolute count and split percentage.

        Parameters
        ----------
        splits_data : dict
            Output of ``analyze_dataset_splits``.
        """
        fig, ax = plt.subplots(figsize=(10, 7))
        splits = list(splits_data.keys())
        x = np.arange(len(splits))
        # Get data
        images_with_santa = [splits_data[s]['images_with_annotations'] for s in splits]
        images_without_santa = [
            splits_data[s]['num_images'] - splits_data[s]['images_with_annotations']
            for s in splits
        ]
        total_images = [splits_data[s]['num_images'] for s in splits]
        # Colors
        color_with = '#e74c3c'     # red for Santa images
        color_without = '#95a5a6'  # gray for background/no Santa
        # Stacked bars
        ax.bar(x, images_with_santa,
               label='Images with Santa',
               color=color_with,
               edgecolor='black',
               alpha=0.8,
               linewidth=1.5)
        ax.bar(x, images_without_santa,
               bottom=images_with_santa,
               label='Images without Santa (background)',
               color=color_without,
               edgecolor='black',
               alpha=0.8,
               linewidth=1.5)
        # Add value labels on bars
        for i, (with_santa, without_santa, total) in enumerate(
                zip(images_with_santa, images_without_santa, total_images)):
            # Label for "with Santa" section
            if with_santa > 0:
                ax.text(i, with_santa/2,
                        f'{with_santa}\n({100*with_santa/total:.1f}%)',
                        ha='center', va='center',
                        fontsize=11, fontweight='bold',
                        color='white')
            # Label for "without Santa" section
            if without_santa > 0:
                ax.text(i, with_santa + without_santa/2,
                        f'{without_santa}\n({100*without_santa/total:.1f}%)',
                        ha='center', va='center',
                        fontsize=11, fontweight='bold',
                        color='white')
            # Total on top
            ax.text(i, total, f'Total: {total}',
                    ha='center', va='bottom',
                    fontsize=12, fontweight='bold',
                    color='black')
        # Formatting
        ax.set_xlabel('Dataset Split', fontsize=13, fontweight='bold')
        ax.set_ylabel('Number of Images', fontsize=13, fontweight='bold')
        ax.set_title('Dataset Split Distribution: Images with/without Santa',
                     fontsize=15, fontweight='bold', pad=20)
        ax.set_xticks(x)
        ax.set_xticklabels([s.upper() for s in splits], fontsize=12, fontweight='bold')
        ax.legend(fontsize=11, loc='upper right', framealpha=0.9)
        ax.grid(True, alpha=0.3, axis='y')
        ax.set_ylim(bottom=0)
        plt.tight_layout()
        plt.show()

    def analyze_image_dimensions(self):
        """
        Sample image dimensions across all splits (up to 100 images/split).

        Returns
        -------
        tuple of np.ndarray
            ``(widths, heights, aspect_ratios)`` in pixels and W/H.
        """
        widths = []
        heights = []
        aspect_ratios = []
        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue
            img_path, _ = self._get_split_paths(split)
            if not os.path.exists(img_path):
                continue
            image_files = [f for f in os.listdir(img_path)
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            if not image_files:
                # np.random.choice raises on an empty population
                continue
            # Sample up to 100 images per split for speed
            sampled_files = np.random.choice(image_files,
                                             min(100, len(image_files)),
                                             replace=False)
            for img_file in sampled_files:
                img_full_path = os.path.join(img_path, img_file)
                try:
                    # PIL reads only the header here, so this is cheap
                    with Image.open(img_full_path) as img:
                        w, h = img.size
                    widths.append(w)
                    heights.append(h)
                    aspect_ratios.append(w / h)
                except OSError:
                    # Unreadable/corrupt image: skip. (Was a bare `except:`,
                    # which also swallowed KeyboardInterrupt and real bugs.)
                    continue
        return np.array(widths), np.array(heights), np.array(aspect_ratios)

    def plot_image_dimensions(self, widths, heights, aspect_ratios):
        """
        Plot 2: image-dimension scatter plus width/height histograms.

        Parameters
        ----------
        widths, heights, aspect_ratios : np.ndarray
            Output of ``analyze_image_dimensions``.
        """
        # Guard: mean()/median() of empty arrays warn or raise
        if len(widths) == 0:
            print("β οΈ No readable images found - nothing to plot")
            return
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))
        # Scatter plot: width vs height, colored by aspect ratio
        ax = axes[0]
        scatter = ax.scatter(widths, heights, alpha=0.5, s=50, c=aspect_ratios,
                             cmap='viridis', edgecolors='black', linewidth=0.5)
        ax.set_xlabel('Width (pixels)', fontsize=12)
        ax.set_ylabel('Height (pixels)', fontsize=12)
        ax.set_title('Image Dimensions Distribution', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)
        # Add average lines
        avg_w, avg_h = widths.mean(), heights.mean()
        ax.axvline(avg_w, color='red', linestyle='--', linewidth=2,
                   label=f'Avg W: {avg_w:.0f}')
        ax.axhline(avg_h, color='blue', linestyle='--', linewidth=2,
                   label=f'Avg H: {avg_h:.0f}')
        ax.legend()
        # Colorbar
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label('Aspect Ratio (W/H)', fontsize=10)
        # Histogram: widths
        ax = axes[1]
        ax.hist(widths, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
        ax.axvline(avg_w, color='red', linestyle='--', linewidth=2,
                   label=f'Mean: {avg_w:.0f}')
        ax.axvline(np.median(widths), color='green', linestyle='--', linewidth=2,
                   label=f'Median: {np.median(widths):.0f}')
        ax.set_xlabel('Width (pixels)', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title('Width Distribution', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')
        # Histogram: heights
        ax = axes[2]
        ax.hist(heights, bins=30, color='lightcoral', edgecolor='black', alpha=0.7)
        ax.axvline(avg_h, color='blue', linestyle='--', linewidth=2,
                   label=f'Mean: {avg_h:.0f}')
        ax.axvline(np.median(heights), color='green', linestyle='--', linewidth=2,
                   label=f'Median: {np.median(heights):.0f}')
        ax.set_xlabel('Height (pixels)', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title('Height Distribution', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        plt.show()
        # Print statistics
        print(f"\nπ IMAGE DIMENSION STATISTICS:")
        print(f" Width: mean={widths.mean():.0f}, median={np.median(widths):.0f}, "
              f"std={widths.std():.0f}")
        print(f" Height: mean={heights.mean():.0f}, median={np.median(heights):.0f}, "
              f"std={heights.std():.0f}")
        print(f" Aspect Ratio: mean={aspect_ratios.mean():.2f}, "
              f"median={np.median(aspect_ratios):.2f}")

    def create_annotation_heatmap(self, grid_size=20):
        """
        Plot 3: heatmap of annotation-center locations across all splits.

        Parameters
        ----------
        grid_size : int
            Heatmap resolution; normalized coordinates are binned into a
            ``grid_size x grid_size`` grid.

        Returns
        -------
        np.ndarray
            The accumulated (grid_size, grid_size) count matrix.
        """
        # Accumulator over normalized coordinates 0-1
        heatmap = np.zeros((grid_size, grid_size))
        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue
            img_path, label_path = self._get_split_paths(split)
            if not os.path.exists(label_path):
                continue
            label_files = [f for f in os.listdir(label_path) if f.endswith('.txt')]
            for label_file in label_files:
                label_file_path = os.path.join(label_path, label_file)
                with open(label_file_path, 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) >= 5:
                            # YOLO format: class x_center y_center w h (normalized)
                            x_center = float(parts[1])
                            y_center = float(parts[2])
                            # Map to grid and clamp (x_center == 1.0 would
                            # otherwise index one past the last cell)
                            grid_x = max(0, min(grid_size - 1, int(x_center * grid_size)))
                            grid_y = max(0, min(grid_size - 1, int(y_center * grid_size)))
                            heatmap[grid_y, grid_x] += 1
        # Plot heatmap
        fig, ax = plt.subplots(figsize=(10, 8))
        im = ax.imshow(heatmap, cmap='hot', interpolation='bilinear', origin='upper')
        ax.set_xlabel('Normalized X Position', fontsize=12)
        ax.set_ylabel('Normalized Y Position', fontsize=12)
        ax.set_title('Annotation Center Heatmap (All Splits)',
                     fontsize=14, fontweight='bold')
        # Add colorbar
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label('Annotation Density', fontsize=12)
        # Add grid; step clamped to >= 1 so small grid_size values don't
        # pass a zero step to np.arange/range
        step = max(1, grid_size // 5)
        ax.set_xticks(np.arange(0, grid_size, step))
        ax.set_yticks(np.arange(0, grid_size, step))
        ax.set_xticklabels([f'{x/grid_size:.1f}' for x in range(0, grid_size, step)])
        ax.set_yticklabels([f'{y/grid_size:.1f}' for y in range(0, grid_size, step)])
        ax.grid(True, alpha=0.3, color='white', linewidth=1)
        plt.tight_layout()
        plt.show()
        return heatmap

    def analyze_objects_per_image(self):
        """
        Collect the number of annotated objects for every image.

        Images without a label file count as zero; blank label lines are
        ignored.

        Returns
        -------
        np.ndarray
            One entry per image across all splits.
        """
        object_counts = []
        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue
            img_path, label_path = self._get_split_paths(split)
            if not os.path.exists(img_path) or not os.path.exists(label_path):
                continue
            image_files = [f for f in os.listdir(img_path)
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            for img_file in image_files:
                label_file = os.path.splitext(img_file)[0] + '.txt'
                label_file_path = os.path.join(label_path, label_file)
                count = 0
                if os.path.exists(label_file_path):
                    with open(label_file_path, 'r') as f:
                        count = sum(1 for line in f if line.strip())
                object_counts.append(count)
        return np.array(object_counts)

    def plot_objects_per_image(self, object_counts):
        """
        Plot 4: histogram of object count per image with mean/median markers.

        Parameters
        ----------
        object_counts : np.ndarray
            Output of ``analyze_objects_per_image``.
        """
        # Guard: the original crashed with ValueError on an empty array
        if len(object_counts) == 0:
            print("β οΈ No images found - nothing to plot")
            return
        fig, ax = plt.subplots(figsize=(12, 6))
        max_count = int(object_counts.max())
        # One bin per integer count, inclusive of max_count
        bins = range(0, max_count + 2)
        ax.hist(object_counts, bins=bins, edgecolor='black',
                color='mediumseagreen', alpha=0.7)
        # Statistics
        mean_count = object_counts.mean()
        median_count = np.median(object_counts)
        ax.axvline(mean_count, color='red', linestyle='--', linewidth=2,
                   label=f'Mean: {mean_count:.2f}')
        ax.axvline(median_count, color='blue', linestyle='--', linewidth=2,
                   label=f'Median: {median_count:.0f}')
        ax.set_xlabel('Number of Objects per Image', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title('Distribution of Object Count per Image',
                     fontsize=14, fontweight='bold')
        ax.legend(fontsize=12)
        ax.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        plt.show()
        # Print statistics
        print(f"\nπ¦ OBJECTS PER IMAGE STATISTICS:")
        print(f" Mean: {mean_count:.2f}")
        print(f" Median: {median_count:.0f}")
        print(f" Max: {max_count}")
        print(f" Images with 0 objects: {np.sum(object_counts == 0)}")
        print(f" Images with 1 object: {np.sum(object_counts == 1)}")
        print(f" Images with 2+ objects: {np.sum(object_counts >= 2)}")

    def visualize_sample_images(self, num_samples=6, split='train'):
        """
        Plot 5: random sample images with their bounding boxes drawn.

        Parameters
        ----------
        num_samples : int
            Number of images to sample. The grid adapts (3 columns); the
            default of 6 keeps the original 2x3 layout.
        split : str
            Dataset split to sample from.
        """
        img_path, label_path = self._get_split_paths(split)
        if not os.path.exists(img_path):
            print(f"β οΈ {split} images not found")
            return
        image_files = [f for f in os.listdir(img_path)
                       if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        if not image_files:
            # np.random.choice raises on an empty population
            print(f"β οΈ {split} split contains no images")
            return
        # Sample random images without replacement
        sampled_files = np.random.choice(image_files,
                                         min(num_samples, len(image_files)),
                                         replace=False)
        # Grid adapts to the sample size instead of a hard-coded 2x3
        cols = 3
        rows = max(1, int(np.ceil(len(sampled_files) / cols)))
        fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 6 * rows),
                                 squeeze=False)
        axes = axes.flatten()
        for idx, img_file in enumerate(sampled_files):
            # Load image (cv2.imread returns None for unreadable files)
            img_full_path = os.path.join(img_path, img_file)
            img = cv2.imread(img_full_path)
            if img is None:
                print(f"β οΈ Could not read {img_full_path}")
                axes[idx].axis('off')
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            h, w = img.shape[:2]
            # Load annotations
            label_file = os.path.splitext(img_file)[0] + '.txt'
            label_file_path = os.path.join(label_path, label_file)
            if os.path.exists(label_file_path):
                with open(label_file_path, 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) < 5:
                            continue
                        cls_id = int(parts[0])
                        # Normalized center/size -> pixel corners
                        x_center = float(parts[1]) * w
                        y_center = float(parts[2]) * h
                        box_w = float(parts[3]) * w
                        box_h = float(parts[4]) * h
                        x1 = int(x_center - box_w / 2)
                        y1 = int(y_center - box_h / 2)
                        x2 = int(x_center + box_w / 2)
                        y2 = int(y_center + box_h / 2)
                        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
                        # Indexing works for both list and dict containers
                        # (the original ternary had two identical branches)
                        class_name = self.class_names[cls_id]
                        cv2.putText(img, class_name, (x1, y1 - 10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
            axes[idx].imshow(img)
            axes[idx].set_title(f'{split}: {img_file}', fontsize=10)
            axes[idx].axis('off')
        # Hide unused subplots
        for idx in range(len(sampled_files), rows * cols):
            axes[idx].axis('off')
        plt.suptitle(f'Sample Images with Annotations ({split.upper()} split)',
                     fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        plt.show()

    def run_full_eda(self):
        """
        Run the complete EDA pipeline with all visualizations.
        """
        print("="*70)
        print("π YOLO DATASET EXPLORATORY DATA ANALYSIS")
        print("="*70)
        # 1. Dataset splits analysis
        print("\nπ Analyzing dataset splits...")
        splits_data = self.analyze_dataset_splits()
        self.plot_split_statistics(splits_data)
        # 2. Image dimensions
        print("\nπ Analyzing image dimensions...")
        widths, heights, aspect_ratios = self.analyze_image_dimensions()
        self.plot_image_dimensions(widths, heights, aspect_ratios)
        # 3. Annotation heatmap
        print("\nπΊοΈ Creating annotation heatmap...")
        self.create_annotation_heatmap(grid_size=20)
        # 4. Objects per image
        print("\nπ¦ Analyzing objects per image...")
        object_counts = self.analyze_objects_per_image()
        self.plot_objects_per_image(object_counts)
        # 5. Sample images
        print("\nπΌοΈ Visualizing sample images...")
        for split in ['train', 'val']:
            if split in self.config:
                print(f"\n {split.upper()} samples:")
                self.visualize_sample_images(num_samples=6, split=split)
        print("\n" + "="*70)
        # Fix: the original string literal was split across two physical
        # lines (broken emoji in the notebook export) - a SyntaxError.
        print("β EDA COMPLETE!")
        print("="*70)
# =======================
# USAGE
# =======================
# Initialize EDA
# NOTE(review): the next two lines are no-op self-assignments; they rely on
# `yaml_path` and `dataset_root` already being defined in an earlier notebook
# cell, and raise NameError otherwise. Replace them with explicit paths when
# running this cell standalone.
yaml_path = yaml_path # Update this path
dataset_root = dataset_root # Update if needed
eda = YOLODatasetEDA(yaml_path, dataset_root)
# Run full analysis
eda.run_full_eda()
# Or run individual analyses
# splits_data = eda.analyze_dataset_splits()
# eda.plot_split_statistics(splits_data)
# widths, heights, aspect_ratios = eda.analyze_image_dimensions()
# eda.plot_image_dimensions(widths, heights, aspect_ratios)
Output hidden; open in https://colab.research.google.com to view.